package com.hefan.common.orm.dao;

import com.cat.tiger.util.GlobalConstants;
import com.google.common.base.CaseFormat;
import com.google.common.collect.Maps;
import com.cat.common.entity.Page;
import com.hefan.common.orm.annotation.Entity;
import com.hefan.common.orm.domain.BaseEntity;
import com.hefan.common.orm.annotation.Column;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.core.BeanPropertyRowMapper;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.jdbc.core.simple.SimpleJdbcInsert;

import javax.annotation.Resource;
import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @author ninglijun
 * @author diguage
 */
/**
 * Generic JDBC-backed implementation of {@link BaseDao}.
 * <p/>
 * The entity type {@code T} is resolved reflectively from the generic superclass of the concrete DAO,
 * so this class must be subclassed with a concrete type argument
 * (e.g. {@code class UserDaoImpl extends BaseDaoImpl<User>}).
 * Only fields annotated with {@link Column} are written by {@link #save(BaseEntity)} and
 * {@link #saveBackRowNum(BaseEntity)}.
 *
 * @author ninglijun
 * @author diguage
 */
public class BaseDaoImpl<T extends BaseEntity> implements BaseDao<T> {

    protected Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * All declared fields (including inherited ones, excluding {@link Object}) per entity class,
     * keyed by the entity's fully-qualified class name. Shared across all DAO instances.
     */
    private static final ConcurrentHashMap<String, Set<Field>> typeFields = new ConcurrentHashMap<String, Set<Field>>();

    private static final String TABLE_FORMATOR = "`%s`";

    @Resource
    protected JdbcTemplate jdbcTemplate;

    @Resource
    protected NamedParameterJdbcTemplate namedParameterJdbcTemplate;

    protected void setJdbcTemplate(JdbcTemplate jdbcTemplate) {
        this.jdbcTemplate = jdbcTemplate;
    }

    protected JdbcTemplate getJdbcTemplate() {
        return jdbcTemplate;
    }

    protected NamedParameterJdbcTemplate getNamedParameterJdbcTemplate() {
        return namedParameterJdbcTemplate;
    }

    protected void setNamedParameterJdbcTemplate(NamedParameterJdbcTemplate namedParameterJdbcTemplate) {
        this.namedParameterJdbcTemplate = namedParameterJdbcTemplate;
    }

    /**
     * Queries a list of entities using a named-parameter SQL statement with a single collection parameter.
     *
     * @param sql         SQL statement containing the named placeholder, e.g. {@code WHERE id IN (:ids)}
     * @param placeholder name of the placeholder (without the leading colon)
     * @param selector    values bound to the placeholder, e.g. a set of IDs
     * @return the matching rows mapped onto {@code T}
     */
    protected List<T> getEntityList(String sql, String placeholder, Set<?> selector) {
        MapSqlParameterSource parameters = new MapSqlParameterSource();
        parameters.addValue(placeholder, selector);
        return getNamedParameterJdbcTemplate().query(sql, parameters, new BeanPropertyRowMapper<T>(getGenericClass()));
    }

    /**
     * Returns the database table name derived from the entity's {@link Entity} annotation or class name.
     * <p/>
     * If the annotation is present and its {@code tableName} is not blank, that value is used.
     * <p/>
     * Otherwise (annotation missing or blank), the simple class name is converted from UpperCamel
     * to lower_underscore and used as the table name.
     *
     * @return the table name, formatted as <code>`tableName`</code>
     */
    public String getTableName() {
        Class<T> clazz = getGenericClass();
        Entity entity = clazz.getAnnotation(Entity.class);
        // Checking for null as well: an un-annotated entity falls back to its class name instead of an NPE.
        if (entity != null && StringUtils.isNotBlank(entity.tableName())) {
            return formatTableName(entity.tableName());
        }
        return formatTableName(CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, clazz.getSimpleName()));
    }

    /**
     * Wraps a table name in backticks, producing <code>`tableName`</code>.
     *
     * @param tableName raw table name
     * @return the formatted table name
     */
    private String formatTableName(String tableName) {
        return String.format(TABLE_FORMATOR, tableName);
    }

    /**
     * Resolves the entity class {@code T} from the first parameterized superclass of the concrete DAO.
     *
     * @return the entity class
     * @throws IllegalStateException if no parameterized superclass with a concrete class argument exists
     */
    @SuppressWarnings("unchecked")
    protected Class<T> getGenericClass() {
        Type t = getClass().getGenericSuperclass();
        while (!(t instanceof ParameterizedType)) {
            // Walk the real class hierarchy. A null (past Object) or non-Class type means we ran out of
            // superclasses without finding a type argument; fail fast instead of looping forever.
            if (!(t instanceof Class)) {
                throw new IllegalStateException("Cannot resolve entity type of " + getClass().getName()
                        + ": it must extend BaseDaoImpl with a concrete type argument");
            }
            t = ((Class<?>) t).getGenericSuperclass();
        }
        Type arg = ((ParameterizedType) t).getActualTypeArguments()[0];
        if (!(arg instanceof Class)) {
            throw new IllegalStateException("Entity type argument of " + getClass().getName()
                    + " is not a concrete class: " + arg);
        }
        return (Class<T>) arg;
    }

    /**
     * Returns all fields declared by the entity class and its superclasses (excluding {@link Object}),
     * caching the result per entity class.
     */
    private Set<Field> getAllFields() {
        Class<?> entityClass = getGenericClass();
        // Compute the key before walking the hierarchy; the walk moves the cursor up to Object.
        String cacheKey = entityClass.getName();
        Set<Field> cached = typeFields.get(cacheKey);
        if (cached != null) {
            return cached;
        }
        Set<Field> fields = new HashSet<Field>();
        for (Class<?> c = entityClass; c != null && c != Object.class; c = c.getSuperclass()) {
            fields.addAll(Arrays.asList(c.getDeclaredFields()));
        }
        // Unmodifiable because the set is shared between all DAO instances through the static cache.
        Set<Field> computed = Collections.unmodifiableSet(fields);
        Set<Field> previous = typeFields.putIfAbsent(cacheKey, computed);
        return previous != null ? previous : computed;
    }

    /**
     * Maps every {@link Column}-annotated field name of the entity to its database column name.
     */
    private Map<String, String> getOrmColumns() {
        Set<Field> allFields = getAllFields();
        Map<String, String> fieldsRelColumn = new HashMap<String, String>();
        for (Field f : allFields) {
            Column column = f.getAnnotation(Column.class);
            if (column == null) {
                continue;
            }
            fieldsRelColumn.put(f.getName(),
                    getColumnNameByFieldNameOrAnnotationValue(f, column));
        }
        return fieldsRelColumn;
    }

    /**
     * Returns the database column name for a field.
     * <p/>
     * If the {@link Column} annotation's value is not blank, that value is used.
     * <p/>
     * Otherwise the field name is converted from lowerCamel to lower_underscore.
     *
     * @param field  the field
     * @param column the field's Column annotation
     * @return the database column name
     */
    private String getColumnNameByFieldNameOrAnnotationValue(Field field, Column column) {
        if (StringUtils.isNotBlank(column.value())) {
            return column.value();
        } else {
            return CaseFormat.LOWER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, field.getName());
        }
    }

    /**
     * Reads the value of the named field from {@code t}.
     * <p/>
     * The entity class is searched first, then its superclasses.
     *
     * @return the field value, or {@code null} if no class in the hierarchy declares the field
     * @throws IllegalStateException if the field exists but cannot be read
     */
    private Object getFieldValue(String field, T t) {
        for (Class<?> clazz = getGenericClass(); clazz != null && clazz != Object.class; clazz = clazz.getSuperclass()) {
            Field f;
            try {
                f = clazz.getDeclaredField(field);
            } catch (NoSuchFieldException e) {
                // Not declared on this class; keep searching the superclass.
                continue;
            }
            try {
                f.setAccessible(true);
                return f.get(t);
            } catch (IllegalAccessException e) {
                // Do not silently persist NULL for a field we failed to read.
                throw new IllegalStateException("Cannot read field '" + field + "' of " + clazz.getName(), e);
            }
        }
        return null;
    }

    /**
     * Collects the non-null {@link Column} values of {@code t}, keyed by database column name.
     * <p/>
     * Null values are left out, so those columns are omitted from the generated INSERT.
     */
    private Map<String, Object> buildInsertParameters(T t) {
        Map<String, String> ormColumns = getOrmColumns();
        Map<String, Object> parameters = new HashMap<String, Object>();
        for (Map.Entry<String, String> entry : ormColumns.entrySet()) {
            Object fieldVal = getFieldValue(entry.getKey(), t);
            if (fieldVal != null) {
                parameters.put(entry.getValue(), fieldVal);
            }
        }
        return parameters;
    }

    /**
     * Updates every {@link Column}-mapped column of the row whose {@code id} equals {@code t.getId()}.
     * <p/>
     * Null field values are written as NULL.
     *
     * @return the same instance that was passed in
     */
    private T update(T t) {
        Map<String, String> ormColumns = getOrmColumns();
        if (ormColumns.isEmpty()) {
            // "UPDATE x SET  WHERE id = ?" would be invalid SQL; nothing to write.
            logger.warn("No @Column fields mapped for {}, skipping update of id {}", getTableName(), t.getId());
            return t;
        }
        StringBuilder sql = new StringBuilder("UPDATE").append(" " + getTableName()).append(" SET ");
        List<Object> params = new ArrayList<Object>();
        StringBuilder subSql = new StringBuilder();
        for (Map.Entry<String, String> entry : ormColumns.entrySet()) {
            if (subSql.length() > 0) {
                subSql.append(", ");
            }
            subSql.append(entry.getValue()).append(" = ? ");
            params.add(getFieldValue(entry.getKey(), t));
        }
        sql.append(subSql);
        sql.append(" WHERE id= ?");
        params.add(t.getId());

        getJdbcTemplate().update(sql.toString(), params.toArray());
        return t;
    }

    /**
     * Inserts or updates an entity.
     * <p/>
     * If {@code t.getId() > 0}, the existing row is updated and {@code t} is returned.
     * Otherwise the non-null mapped columns are inserted. The generated {@code id} key is then used to
     * reload and return the stored row.
     */
    @Override
    public T save(T t) {
        if (t.getId() > 0) {
            return update(t);
        }
        Map<String, Object> parameters = buildInsertParameters(t);

        String[] columns = parameters.keySet().toArray(new String[]{});
        SimpleJdbcInsert simpleJdbcInsert = new SimpleJdbcInsert(getJdbcTemplate().getDataSource())
                .usingColumns(columns)
                .withTableName(getTableName())
                .usingGeneratedKeyColumns("id");
        Number id = simpleJdbcInsert.executeAndReturnKey(parameters);
        return get(id.longValue());
    }

    /**
     * Inserts the non-null mapped columns of an entity that has no id yet.
     *
     * @return the number of affected rows, or 0 without touching the database when {@code t.getId() != 0}
     */
    @Override
    public int saveBackRowNum(T t) {
        if (t.getId() != 0) {
            return 0;
        }
        Map<String, Object> parameters = buildInsertParameters(t);

        String[] columns = parameters.keySet().toArray(new String[]{});
        SimpleJdbcInsert simpleJdbcInsert = new SimpleJdbcInsert(getJdbcTemplate().getDataSource())
                .usingColumns(columns)
                .withTableName(getTableName());
        return simpleJdbcInsert.execute(parameters);
    }

    /**
     * Loads the row with the given id.
     *
     * @return the entity, or {@code null} if no row matches
     */
    @Override
    public T get(long id) {
        String sql = " SELECT * " +
                " FROM " + getTableName() +
                " WHERE id = ?";
        List<T> list = getJdbcTemplate().query(sql, new Object[]{id}, new BeanPropertyRowMapper<T>(getGenericClass()));

        if (CollectionUtils.isNotEmpty(list)) {
            return list.get(0);
        }
        return null;
    }

    /**
     * Runs an arbitrary positional-parameter query and returns its first row.
     *
     * @param params positional parameters; {@code null} is treated as no parameters
     * @return the first mapped row, or {@code null} if the query returns no rows
     */
    @Override
    public T get(String sql, List<Object> params) {
        Object[] args = params == null ? new Object[0] : params.toArray();
        List<T> list = getJdbcTemplate().query(sql, args, new BeanPropertyRowMapper<T>(getGenericClass()));

        if (CollectionUtils.isNotEmpty(list)) {
            return list.get(0);
        }
        return null;
    }

    /**
     * Loads all rows whose id is in {@code ids}.
     *
     * @return a map from id to entity; empty (no query issued) when {@code ids} is null or empty
     */
    @Override
    public Map<Long, T> find(Set<Long> ids) {
        Map<Long, T> result = Maps.newHashMap();
        if (CollectionUtils.isEmpty(ids)) {
            return result;
        }

        String sql = " SELECT * " +
                " FROM " + getTableName() +
                " WHERE id IN ( :ids )";

        List<T> orderList = getEntityList(sql, "ids", ids);

        if (CollectionUtils.isNotEmpty(orderList)) {
            for (T t : orderList) {
                result.put(t.getId(), t);
            }
        }

        return result;
    }

    /**
     * Soft-deletes a row by setting {@code delete_flag} to {@link GlobalConstants#DELETE_FLAG_YES}.
     *
     * @return {@code true} if at least one row was updated
     */
    @Override
    public boolean delete(long id) {
        String sql = " UPDATE " + getTableName() +
                " SET delete_flag = " + GlobalConstants.DELETE_FLAG_YES +
                " WHERE id = ?";
        int rows = getJdbcTemplate().update(sql, id);
        return rows > 0;
    }

    /**
     * Soft-deletes all rows whose id is in {@code ids} by setting {@code delete_flag}
     * to {@link GlobalConstants#DELETE_FLAG_YES}.
     *
     * @return {@code true} if at least one row was updated; {@code false} (no query issued) for null or empty ids
     */
    @Override
    public boolean delete(Set<Long> ids) {
        // An empty set would expand to "IN ( )", which is invalid SQL.
        if (CollectionUtils.isEmpty(ids)) {
            return false;
        }
        String sql = " UPDATE " + getTableName() +
                " SET delete_flag = " + GlobalConstants.DELETE_FLAG_YES +
                " WHERE id IN ( :ids ) ";

        MapSqlParameterSource parameters = new MapSqlParameterSource();
        parameters.addValue("ids", ids);

        int rows = getNamedParameterJdbcTemplate().update(sql, parameters);

        return rows > 0;
    }

    /**
     * Physically deletes the row with the given id.
     *
     * @return {@code true} if at least one row was deleted
     */
    @Override
    public boolean realDelete(long id) {
        String sql = "delete from " + getTableName() + " where id = ?";
        int rows = getJdbcTemplate().update(sql, id);
        return rows > 0;
    }

    /**
     * Physically deletes all rows whose id is in {@code ids}.
     *
     * @return {@code true} if at least one row was deleted; {@code false} (no query issued) for null or empty ids
     */
    @Override
    public boolean realDelete(Set<Long> ids) {
        // An empty set would expand to "IN ( )", which is invalid SQL.
        if (CollectionUtils.isEmpty(ids)) {
            return false;
        }
        String sql = "DELETE FROM " + getTableName() +
                " WHERE id IN ( :ids ) ";

        MapSqlParameterSource parameters = new MapSqlParameterSource();
        parameters.addValue("ids", ids);

        int rows = getNamedParameterJdbcTemplate().update(sql, parameters);

        return rows > 0;
    }

    @Override
    public List<Map<String, Object>> queryMap(String sql, Object... params) {
        return getJdbcTemplate().queryForList(sql, params);
    }

    @Override
    public List<T> query(String sql, Object... params) {
        return getJdbcTemplate().query(sql, params, new BeanPropertyRowMapper<T>(getGenericClass()));
    }

    /**
     * Wraps {@code sql} in a {@code SELECT count(1)} subquery to count its rows.
     */
    private String buildCountSql(String sql) {
        return "SELECT count(1) AS c FROM (" + sql + ") t";
    }

    /**
     * Appends the page's optional ORDER BY clause and its LIMIT clause to {@code sql}.
     */
    private String buildPageSql(Page<?> page, String sql) {
        String pageSql = sql;
        if (StringUtils.isNotBlank(page.getOrderBy()) && StringUtils.isNotBlank(page.getOrder())) {
            // NOTE(review): orderBy/order are concatenated verbatim into SQL. If they can come from a
            // request, they must be whitelisted by the caller to prevent SQL injection.
            pageSql += " ORDER BY " + page.getOrderBy() + " " + page.getOrder() + " ";
        }
        pageSql += " LIMIT " + page.getOffset() + "," + page.getPageSize();
        return pageSql;
    }

    /**
     * Runs a paged query and maps rows onto {@code T}.
     * <p/>
     * The total row count is stored with {@code setTotalItems}. When it is 0, the data query is
     * skipped and the page is returned without calling {@code setResult}.
     */
    @Override
    public Page<T> findPage(Page<T> page, String sql, Object... params) {
        String countSql = buildCountSql(sql);
        logger.info(countSql);
        int count = getJdbcTemplate().queryForObject(countSql, params, Integer.class);
        page.setTotalItems(count);
        if (count == 0) {
            return page;
        }

        String pageSql = buildPageSql(page, sql);
        logger.info(pageSql);
        List<T> result = getJdbcTemplate().query(pageSql, new BeanPropertyRowMapper<T>(getGenericClass()), params);
        page.setResult(result);
        return page;
    }

    /**
     * Runs a paged query and returns rows as column-name to value maps.
     * <p/>
     * The total row count is stored with {@code setTotalItems}. When it is 0, the data query is
     * skipped and the page is returned without calling {@code setResult}.
     */
    @Override
    public Page<Map<String, Object>> findPageMap(Page<Map<String, Object>> page, String sql, Object... params) {
        String countSql = buildCountSql(sql);
        logger.info(countSql);
        int count = getJdbcTemplate().queryForObject(countSql, params, Integer.class);
        page.setTotalItems(count);
        if (count == 0) {
            return page;
        }

        String pageSql = buildPageSql(page, sql);
        logger.info(pageSql);
        List<Map<String, Object>> result = getJdbcTemplate().queryForList(pageSql, params);
        page.setResult(result);
        return page;
    }
}
